# Threshold Cellpose masks and calculate mask overlap

import os
import numpy as np
import pandas as pd
from importlib import reload

from pathlib import Path
from skimage import io

from matplotlib import pyplot as plt

import img_util.io
import img_util.util
import img_util.imgproc
import img_util.plot
import img_util.threshold

# Define IO paths. Every directory string ends in '/', so the rest of the script
# builds paths by plain string concatenation with sub-folder and file names.
dropbox_basedir = 'C:/Users/timhc/Dropbox/Work/Kang Lab/Exp/2021-01-09 D2GFT32-35 LDR decay SKF Quin Fos/'
cellpose_mask_dir = f'{dropbox_basedir}cellpose mask/for_thresh/'  # Cellpose masks to threshold
img_basedir = f'{dropbox_basedir}crop img/'  # images the masks were computed on
excl_roi_dir = f'{dropbox_basedir}exclude ROI/'
img_save_basedir = f'{dropbox_basedir}img_w_masks_for_pub/'  # output: images with mask overlays
thresholded_mask_save_dir = f'{dropbox_basedir}thresholded mask/'  # output: thresholded mask images
dataframe_save_dir = f'{dropbox_basedir}quant mask/'  # output: quantification CSV
# NOTE(review): this CSV name refers to D2GFT9-42 / DA antagonists while the base
# directory is the D2GFT32-35 experiment -- confirm this is intended.
df_save_fname = 'D2GFT9-42_DA_antag_mask_quant.csv'

# Output options
indi_subj_save_folder = True  # save images in a sub-folder named after the sample-name prefix (text before first '_')
save_thresholded_mask = False  # also write each channel's thresholded mask image

# Per-channel parameters, laid out as columns of a table (turned into a
# DataFrame indexed by 'channel' below). Each list holds one value per channel,
# in the same order as the 'channel' list.
param = {
    'channel':           ['AF405', 'AF488', 'AF555', 'AF680'],
    'cell diam':         [12,      16,      20,      14],
    'flow threshold':    [0.8,     0.6,     0.6,     0.6],
    'mask threshold':    [-3,      -1,      -1,      -1],
    'overlap threshold': [0.5,     0.5,     0.5,     0.5],
}

# Per-channel mask-thresholding settings. Each entry is handed, as a whole, to
# img_util.imgproc.get_thresh_masks_w_bkgnd(mask_threshold_dict=...) for that
# channel. All channels currently use threshold_mask_center_vs_outline; the other
# keys are its settings (presumably pixel widths of the mask outline/center
# regions and an intensity threshold of type 'diff' -- see img_util.threshold).
mask_threshold_dict = {
    'AF405': dict(threshold_func=img_util.threshold.threshold_mask_center_vs_outline,
                  mask_outline_pix=3,
                  mask_center_pix=0,
                  intst_thld=10,
                  intst_thld_type='diff'),
    'AF488': dict(threshold_func=img_util.threshold.threshold_mask_center_vs_outline,
                  mask_outline_pix=3,
                  mask_center_pix=0,
                  intst_thld=10,
                  intst_thld_type='diff'),
    'AF555': dict(threshold_func=img_util.threshold.threshold_mask_center_vs_outline,
                  mask_outline_pix=3,
                  mask_center_pix=3,
                  intst_thld=40,
                  intst_thld_type='diff'),
    'AF680': dict(threshold_func=img_util.threshold.threshold_mask_center_vs_outline,
                  mask_outline_pix=3,
                  mask_center_pix=0,
                  intst_thld=20,
                  intst_thld_type='diff'),
}

# Channel parameter table, indexed by channel name for per-channel .loc lookups.
df_channel = pd.DataFrame(param).set_index('channel')
print(df_channel)

# Channels to threshold and quantify, plus the channel whose background ROI is
# used for every channel (set to None to fall back to the first processed channel).
process_channels = ['AF405', 'AF488', 'AF555', 'AF680']
bkgnd_roi_channel = 'AF488'
mask_paths, samples_id = img_util.io.confirm_mask_paths_get_sample_id(cellpose_mask_dir, process_channels)

# Background ROI settings.
# rm_bkgnd_iso_area is passed to get_thresh_masks_w_bkgnd as rm_small_bg_area
# (presumably isolated background regions smaller than this, in pix^2, are
# removed). It is the area of one background-channel cell -- a disk of diameter
# 'cell diam' -- scaled by rm_bkgnd_iso_extra_factor.
if bkgnd_roi_channel is not None:
    # Single scalar lookup. The former chained .loc[row][col] first built a
    # mixed-dtype row Series, which upcasts the integer diameter to float.
    bkgnd_cell_diam = df_channel.loc[bkgnd_roi_channel, 'cell diam']
    rm_bkgnd_iso_extra_factor = 1.5  # multiply area of background channel single cell by this factor for isolated area
    rm_bkgnd_iso_area = (bkgnd_cell_diam / 2) ** 2 * np.pi * rm_bkgnd_iso_extra_factor  # isolated background area threshold
else:
    rm_bkgnd_iso_area = None

# Area (pix^2) of a disk of radius 100 pix; passed to get_thresh_masks_w_bkgnd as
# bkgnd_rm_holes_area (presumably holes in the background smaller than this are filled).
rm_bkgnd_small_holes = 100 ** 2 * np.pi

# Dataframe IO. If a CSV from an earlier run exists, its rows are loaded first,
# so that this run's rows come after them when the list is concatenated.
mask_quant = []  # mask quantification for each sample, to be concatenated into dataframe
# parents=True so a missing parent directory does not raise FileNotFoundError.
Path(dataframe_save_dir).mkdir(parents=True, exist_ok=True)
# os.path.join avoids the doubled '/' that `dir + '/' + fname` produced when the
# directory string already ends with a separator.
dataframe_save_path = os.path.join(dataframe_save_dir, df_save_fname)
if os.path.exists(dataframe_save_path):
    mask_quant.append(pd.read_csv(dataframe_save_path))

# Main loop: for each sample, threshold every channel's Cellpose mask, optionally
# save overlay / thresholded-mask images, count masks per channel, quantify
# cross-channel mask overlap, and write the accumulated results to the CSV.
plt.ioff()  # turn off interactive plotting while images are generated and saved
save_img = True  # save each channel's image-with-mask overlay
save_mask_overlap_img = True  # forwarded to quant_channels_mask_overlap


def _tidy_fname(fname):
    """Return *fname* with '__' and then ' _' collapsed to '_'.

    Applied to the file-name part only (never to the directory), so directory
    names that happen to contain these sequences are left unchanged.
    """
    return fname.replace('__', '_').replace(' _', '_')


for sample_i in samples_id:

    # Images go either into a sub-folder named after the sample-name prefix
    # (text before the first '_') or directly into img_save_basedir.
    if indi_subj_save_folder:
        img_save_dir = img_save_basedir + sample_i.split('_')[0] + '/'
    else:
        img_save_dir = img_save_basedir

    channel_mask_dict = {}  # channel -> thresholding outputs for this sample
    channel_orig_img = {}  # channel -> original image, for saving mask overlap images

    for ch in process_channels:
        mask_path = img_util.io.find_mask_path_w_sample_id(mask_paths, sample_i, ch, mask_keyword='mask')
        orig_img = io.imread(img_util.io.get_orig_img_path(mask_path, img_basedir, img_ext='.png'))
        channel_orig_img[ch] = orig_img
        # NOTE(review): the chained lookup reads from a mixed-dtype row, so diameter is
        # likely a float here; left as-is to keep what the helper receives unchanged.
        diameter = df_channel.loc[ch]['cell diam']
        mask_threshold_dict_ch = mask_threshold_dict[ch]

        bg_mask, masks_thresh, img_w_mask = img_util.imgproc.get_thresh_masks_w_bkgnd(
            mask_path, img_basedir, excl_roi_dir, bkgnd_roi_channel=bkgnd_roi_channel, img_ext='.png',
            bkgnd_rm_holes_area=rm_bkgnd_small_holes, bkgnd_open_close_size=3, bkgnd_colour='yellow',
            bkgnd_fill_alpha=0, cell_diameter=diameter, mask_threshold_dict=mask_threshold_dict_ch,
            rm_small_bg_area=rm_bkgnd_iso_area)

        output_dict = {'background mask': bg_mask, 'thresholded masks': masks_thresh,
                       'image with mask': img_w_mask}

        if save_img:
            Path(img_save_dir).mkdir(parents=True, exist_ok=True)
            # str.replace returns a new string; the tidied name must be used, not discarded.
            img_save_fname = img_save_dir + _tidy_fname(sample_i + '_' + ch + '_mask.png')
            io.imsave(img_save_fname, img_w_mask)

        if save_thresholded_mask:
            Path(thresholded_mask_save_dir).mkdir(parents=True, exist_ok=True)
            mask_save_fname = thresholded_mask_save_dir + _tidy_fname(sample_i + '_' + ch + '_thld_mask.png')
            io.imsave(mask_save_fname, masks_thresh[0], check_contrast=False)

        channel_mask_dict[ch] = output_dict

    # Background pixel count: taken from the designated background channel, or from
    # the first processed channel when no background channel is configured.
    bkgnd_ch = bkgnd_roi_channel if bkgnd_roi_channel is not None else process_channels[0]
    bkgnd_mask = channel_mask_dict[bkgnd_ch]['background mask']
    df_sample = pd.DataFrame({'sample_name': [sample_i],
                              'background_npix': bkgnd_mask.sum()})
    for ch in process_channels:
        # number of unique masks -- presumably element [1] of the thresholded-mask
        # output is the collection of kept mask IDs; TODO confirm against img_util
        df_sample[ch] = len(channel_mask_dict[ch]['thresholded masks'][1])

    df_overlap = img_util.util.quant_channels_mask_overlap(process_channels, channel_mask_dict, df_channel,
                                                           channel_orig_img=channel_orig_img,
                                                           save_mask_overlap_img=save_mask_overlap_img,
                                                           img_save_dir=img_save_dir + 'mask_overlap/',
                                                           sample_i=sample_i)
    df_sample = pd.merge(df_sample, df_overlap, on='sample_name')
    mask_quant.append(df_sample)

    # Concatenate across samples and save. This runs inside the loop, so results
    # for completed samples are already on disk if a later sample fails. Rows
    # loaded from an existing CSV precede this run's rows in mask_quant, so
    # keep='last' lets the newest result for a re-processed sample win.
    df_mask_quant = (pd.concat(mask_quant)
                     .drop_duplicates(subset=['sample_name'], keep='last')
                     .reset_index(drop=True)
                     .sort_values(by=['sample_name']))
    df_mask_quant.to_csv(dataframe_save_path, index=False)

